//
// Copyright 2022 The Project Oak Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//

//! Utilities for building interrupt handlers that need to modify the register
//! contents, such as a #VC handler that handles VMM communication exceptions
//! generated by CPUID calls.
//!
//! The "x86-interrupt" calling convention does not allow modification of any
//! general purpose registers, so the #VC handler requires a different approach.
//!
//! Example usage:
//!
//! ```rust
//! #![feature(naked_functions)]
//! #![feature(asm_sym)]
//!
//! use oak_sev_guest::interrupts::{
//!     mutable_interrupt_handler_with_error_code, MutableInterruptStackFrame,
//! };
//!
//! mutable_interrupt_handler_with_error_code!(
//!     unsafe fn vmm_communication_exception_handler(
//!         stack_frame: &mut MutableInterruptStackFrame,
//!         error_code: u64,
//!     ) {
//!         // Provide fake values for CPUID.
//!         match error_code {
//!             // CPUID was intercepted.
//!             0x72 => {
//!                 stack_frame.rax = 0;
//!                 stack_frame.rbx = 0;
//!                 stack_frame.rcx = 0;
//!                 stack_frame.rdx = 0;
//!                 // Advance RIP by the length of the CPUID instruction.
//!                 stack_frame.rip += 2u64;
//!             }
//!             _ => panic!("#VC exception {:x?} not handled", error_code),
//!         };
//!     }
//! );
//! ```

use x86_64::VirtAddr;

/// Register state saved on the stack by the outer interrupt handler.
///
/// A reference to the interrupt stack frame will be passed as the first
/// argument of the inner handler function.
///
/// It will point to the backed-up values of registers on the stack. This makes
/// it possible to modify the register values that will be restored when the
/// outer handler returns.
///
/// The field order mirrors the stack layout built by
/// `mutable_interrupt_handler_with_error_code!`: the general-purpose
/// registers are pushed in reverse order of this definition, so `rax` sits at
/// the lowest address, and `rsi` occupies the slot that held the error code on
/// entry. The struct is `#[repr(C)]` so that the compiler keeps the fields in
/// exactly this order; reordering fields here without changing the assembly
/// would make the handler restore the wrong values.
#[repr(C)]
#[derive(Debug)]
pub struct MutableInterruptStackFrame {
    /// The backed-up value of the RAX register.
    pub rax: u64,
    /// The backed-up value of the RBX register.
    pub rbx: u64,
    /// The backed-up value of the RCX register.
    pub rcx: u64,
    /// The backed-up value of the RDX register.
    pub rdx: u64,
    /// The backed-up value of the RDI register.
    pub rdi: u64,
    /// The backed-up value of the RSI register. The outer handler swaps this
    /// into the stack slot that originally held the error code.
    pub rsi: u64,
    /// The backed-up value of the instruction pointer when the interrupt
    /// happened. Whether this points to the instruction that caused the
    /// interrupt or just after it depends on the interrupt. For a #VC
    /// exception this points to the instruction that caused the exception
    /// (e.g. CPUID), so a handler that emulates the instruction must advance
    /// this value past it before returning.
    pub rip: VirtAddr,
    /// The backed-up value of the code segment selector.
    pub cs: u64,
    /// The backed-up value of the CPU's flags register.
    pub rflags: u64,
    /// The backed-up value of the stack pointer.
    pub rsp: VirtAddr,
    /// The backed-up value of the stack segment.
    pub ss: u64,
}

/// Defines a naked interrupt handler that lets its body modify the interrupted
/// code's registers.
///
/// The macro accepts a single function definition of the form
/// `unsafe fn name(stack_frame: &mut MutableInterruptStackFrame, error_code: u64) { ... }`
/// and expands to a naked `unsafe extern "sysv64" fn name() -> !` that can be
/// installed as the handler for an exception that pushes an error code.
///
/// The generated entry point:
///
/// 1. swaps the error code on top of the stack with RSI, so that the error
///    code becomes the second argument and the saved RSI completes the
///    `MutableInterruptStackFrame` layout;
/// 2. pushes RDI, RDX, RCX, RBX and RAX and passes the resulting stack address
///    as the first argument;
/// 3. backs up R8-R11 and YMM0-YMM15, clears the direction flag and calls the
///    body;
/// 4. restores everything, including any values the body wrote into the stack
///    frame, and returns with `iretq`.
///
/// Requirements at the call site:
///
/// - `MutableInterruptStackFrame` must be in scope, as the expansion refers to
///   it by its unqualified name.
/// - The crate must be able to compile `#[naked]` functions.
/// - The generated code unconditionally uses AVX instructions to save and
///   restore the YMM registers.
///
/// The body is compiled as a safe `extern "sysv64"` function, so any unsafe
/// operation inside it still needs its own `unsafe` block.
#[macro_export]
macro_rules! mutable_interrupt_handler_with_error_code {
    (unsafe fn $name:ident ( $stack_frame:ident : &mut MutableInterruptStackFrame , $error_code:ident : u64 $(,)? ) $code:block) => {
        #[naked]
        unsafe extern "sysv64" fn $name() -> ! {
            use core::arch::asm;

            extern "sysv64" fn inner_function($stack_frame: &mut MutableInterruptStackFrame, $error_code: u64) {
                $code
            }

            asm!(
                // We don't want the error code on the stack. We want the value of RSI on the stack
                // and the error code in RSI, as it will be the second parameter to the inner
                // function.
                "xchg [rsp], rsi",
                // Push other mutable registers in reverse order from the struct definition.
                "push rdi",
                "push rdx",
                "push rcx",
                "push rbx",
                "push rax",
                // We are now at the stack location which we want to pass as the first parameter to
                // the inner function, so move the current stack pointer into RDI.
                "mov rdi, rsp",
                // Back up remaining scratch registers.
                "push r8",
                "push r9",
                "push r10",
                "push r11",
                // Make sure the stack is 16-byte aligned.
                "sub rsp, 8",

                // Back up the AVX registers.
                // TODO(#3329): Update interrupt handler macro to support AVX, SSE or neither.
                "sub rsp, 16*32",
                "vmovups [rsp + 0*32], YMM0",
                "vmovups [rsp + 1*32], YMM1",
                "vmovups [rsp + 2*32], YMM2",
                "vmovups [rsp + 3*32], YMM3",
                "vmovups [rsp + 4*32], YMM4",
                "vmovups [rsp + 5*32], YMM5",
                "vmovups [rsp + 6*32], YMM6",
                "vmovups [rsp + 7*32], YMM7",
                "vmovups [rsp + 8*32], YMM8",
                "vmovups [rsp + 9*32], YMM9",
                "vmovups [rsp + 10*32], YMM10",
                "vmovups [rsp + 11*32], YMM11",
                "vmovups [rsp + 12*32], YMM12",
                "vmovups [rsp + 13*32], YMM13",
                "vmovups [rsp + 14*32], YMM14",
                "vmovups [rsp + 15*32], YMM15",

                // The System V calling convention requires the direction flag to be clear on
                // function entry, but the interrupted code may have left it set. Clear it so that
                // string instructions emitted by the compiler for the inner function behave
                // correctly. The interrupted code's RFLAGS (including DF) is restored by `iretq`.
                "cld",

                // Call the inner function with the System V calling convention. Argument 1 is
                // already in RDI and argument 2 in RSI.
                "call  {INNER_ADDRESS}",

                // Restore AVX registers.
                "vmovups YMM0, [rsp + 0*32]",
                "vmovups YMM1, [rsp + 1*32]",
                "vmovups YMM2, [rsp + 2*32]",
                "vmovups YMM3, [rsp + 3*32]",
                "vmovups YMM4, [rsp + 4*32]",
                "vmovups YMM5, [rsp + 5*32]",
                "vmovups YMM6, [rsp + 6*32]",
                "vmovups YMM7, [rsp + 7*32]",
                "vmovups YMM8, [rsp + 8*32]",
                "vmovups YMM9, [rsp + 9*32]",
                "vmovups YMM10, [rsp + 10*32]",
                "vmovups YMM11, [rsp + 11*32]",
                "vmovups YMM12, [rsp + 12*32]",
                "vmovups YMM13, [rsp + 13*32]",
                "vmovups YMM14, [rsp + 14*32]",
                "vmovups YMM15, [rsp + 15*32]",
                "add rsp, 16*32",

                // Undo stack alignment.
                "add rsp, 8",
                // Restore scratch registers.
                "pop r11",
                "pop r10",
                "pop r9",
                "pop r8",
                // Restore potentially modified general-purpose registers from the interrupt stack
                // frame.
                "pop rax",
                "pop rbx",
                "pop rcx",
                "pop rdx",
                "pop rdi",
                "pop rsi",
                // The stack should now be in the original state minus the error code, so return
                // from the handler.
                "iretq",
                INNER_ADDRESS = sym inner_function,
                options(noreturn)
            )
        }
    };
}

pub use mutable_interrupt_handler_with_error_code;
